/*
 * Copyright (c) 2018, apexes.net. All rights reserved.
 *
 *         http://www.apexes.net
 *
 */
package net.apexes.commons.guice.tx;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Thread-bound transaction support for code demarcated with the {@code @Tx} annotation.
 *
 * <p>Transaction state lives in a {@link ThreadLocal}, so each thread only ever sees its own
 * transaction. Within one transaction a single data source is transactional:
 * <ul>
 *   <li>if {@code @Tx} has no value, the first data source whose connection is requested becomes
 *       the transactional one, and requesting a connection from a differently named data source
 *       afterwards fails with an {@link SQLException};</li>
 *   <li>if {@code @Tx} names a data source, only the data source with that name is
 *       transactional; every other data source hands out non-transactional connections.</li>
 * </ul>
 *
 * <p>The package-private lifecycle methods ({@link #begin(Tx)}, {@link #commit()},
 * {@link #rollback()}, {@link #end()}) are presumably driven by a transaction interceptor in this
 * package. NOTE(review): confirm that the caller pairs every {@code begin} with an {@code end}.
 *
 * @author <a href="mailto:hedyn@foxmail.com">HeDYn</a>
 */
public class Txs {
    private Txs() {
    }

    /** Transaction state of the current thread, or {@code null} when no transaction is active. */
    private static final ThreadLocal<LocalBean> THREAD_LOCAL = new ThreadLocal<>();

    /**
     * Per-thread transaction state: the connection holder plus the callbacks to run after a
     * commit or a rollback.
     */
    private static final class LocalBean {
        final DataSourceHolder dataSourceHolder;
        final List<Runnable> afterCommitList = new ArrayList<>();
        final List<Runnable> afterRollbackList = new ArrayList<>();

        LocalBean(Tx tx) {
            this.dataSourceHolder = new DataSourceHolder(tx);
        }

        /** Commits the bound connection (if any), then runs the after-commit callbacks. */
        void commit() throws SQLException {
            dataSourceHolder.commit();
            runAll(afterCommitList);
        }

        /** Rolls back the bound connection (if any), then runs the after-rollback callbacks. */
        void rollback() throws SQLException {
            dataSourceHolder.rollback();
            runAll(afterRollbackList);
        }

        /**
         * Runs the given callbacks in registration order.
         *
         * <p>This bean is still bound to the thread while callbacks run, so a callback may call
         * {@code addAfterTxCommit}/{@code addAfterTxRollback}. Iterating the live list would then
         * throw {@code ConcurrentModificationException}, so a snapshot is iterated instead.
         * Callbacks registered during this run are not executed by it. An exception thrown by a
         * callback propagates and skips the remaining callbacks.
         */
        private static void runAll(List<Runnable> callbacks) {
            List<Runnable> snapshot = new ArrayList<>(callbacks);
            for (Runnable callback : snapshot) {
                callback.run();
            }
        }
    }

    /**
     * Registers a callback to run after the current thread's transaction commits.
     *
     * @param runnable callback to register; must not be {@code null}
     * @throws NullPointerException if {@code runnable} is {@code null}
     * @deprecated use {@link #addAfterTxCommit(Runnable)}
     */
    @Deprecated
    public static void addAfterCommit(Runnable runnable) {
        addAfterTxCommit(runnable);
    }

    /**
     * Registers a callback to run after the current thread's transaction rolls back.
     *
     * @param runnable callback to register; must not be {@code null}
     * @throws NullPointerException if {@code runnable} is {@code null}
     * @deprecated use {@link #addAfterTxRollback(Runnable)}
     */
    @Deprecated
    public static void addAfterRollback(Runnable runnable) {
        addAfterTxRollback(runnable);
    }

    /**
     * Registers a callback to run after the current thread's transaction has committed.
     *
     * <p>If the current thread is not inside a transaction, the callback is silently ignored: it
     * is neither run nor stored.
     *
     * @param runnable callback to register; must not be {@code null}
     * @throws NullPointerException if {@code runnable} is {@code null}
     */
    public static void addAfterTxCommit(Runnable runnable) {
        Objects.requireNonNull(runnable, "runnable");
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null) {
            bean.afterCommitList.add(runnable);
        }
    }

    /**
     * Registers a callback to run after the current thread's transaction has rolled back.
     *
     * <p>If the current thread is not inside a transaction, the callback is silently ignored: it
     * is neither run nor stored.
     *
     * @param runnable callback to register; must not be {@code null}
     * @throws NullPointerException if {@code runnable} is {@code null}
     */
    public static void addAfterTxRollback(Runnable runnable) {
        Objects.requireNonNull(runnable, "runnable");
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null) {
            bean.afterRollbackList.add(runnable);
        }
    }

    /**
     * Returns whether a transaction is bound to the current thread.
     *
     * @return {@code true} between {@link #begin(Tx)} and {@link #end()} on this thread
     */
    public static boolean isWithinTx() {
        return THREAD_LOCAL.get() != null;
    }

    /**
     * Binds a new transaction with the settings of {@code tx} to the current thread.
     *
     * <p>No connection is opened here. It is created lazily by {@code getConnection} the first
     * time the transactional data source is asked for one.
     *
     * <p>NOTE(review): a transaction already bound to this thread is replaced without being
     * closed. Callers presumably never nest {@code begin} calls; confirm.
     */
    static void begin(Tx tx) {
        THREAD_LOCAL.set(new LocalBean(tx));
    }

    /**
     * Ends the current thread's transaction.
     *
     * <p>Closes the transactional connection (if one was opened) and unbinds the transaction
     * state. The state is unbound even if closing fails. This is a no-op when no transaction is
     * active.
     *
     * @throws SQLException if closing the connection fails
     */
    static void end() throws SQLException {
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null) {
            try {
                bean.dataSourceHolder.close();
            } finally {
                THREAD_LOCAL.remove();
            }
        }
    }

    /**
     * Commits the current thread's transaction, then runs the after-commit callbacks.
     *
     * <p>This is a no-op when no transaction is active.
     *
     * @throws SQLException if the commit fails; callbacks are not run in that case
     */
    static void commit() throws SQLException {
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null) {
            bean.commit();
        }
    }

    /**
     * Rolls back the current thread's transaction, then runs the after-rollback callbacks.
     *
     * <p>This is a no-op when no transaction is active.
     *
     * @throws SQLException if the rollback fails; callbacks are not run in that case
     */
    static void rollback() throws SQLException {
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null) {
            bean.rollback();
        }
    }

    /**
     * Returns a connection from {@code dataSource} for the current thread.
     *
     * <p>When a transaction is active and {@code dataSource} is its transactional data source
     * (see the class documentation), the transaction's connection is returned. That connection
     * is created and configured on first use. Otherwise a non-transactional connection is
     * obtained via {@code getWithoutTxConnection()}.
     *
     * @throws SQLException if the connection cannot be obtained or configured, or if
     *         {@code @Tx} has no value and a second, differently named data source is used
     *         in the same transaction
     */
    static Connection getConnection(TxDataSource dataSource) throws SQLException {
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null && bean.dataSourceHolder.validate(dataSource)) {
            if (bean.dataSourceHolder.dataSource == null) {
                bean.dataSourceHolder.setupConnection(dataSource);
            }
            return bean.dataSourceHolder.connection;
        }
        return dataSource.getWithoutTxConnection();
    }

    /**
     * Returns a connection from {@code dataSource} for the current thread, using the given
     * credentials.
     *
     * <p>This behaves like {@link #getConnection(TxDataSource)}. The credentials only matter when
     * the transactional connection is first created, or for non-transactional connections. Later
     * calls within the same transaction return the already-created connection regardless of the
     * credentials passed.
     *
     * @throws SQLException if the connection cannot be obtained or configured, or if
     *         {@code @Tx} has no value and a second, differently named data source is used
     *         in the same transaction
     */
    static Connection getConnection(TxDataSource dataSource, String username, String password) throws SQLException {
        LocalBean bean = THREAD_LOCAL.get();
        if (bean != null && bean.dataSourceHolder.validate(dataSource)) {
            if (bean.dataSourceHolder.dataSource == null) {
                bean.dataSourceHolder.setupConnection(dataSource, username, password);
            }
            return bean.dataSourceHolder.connection;
        }
        return dataSource.getWithoutTxConnection(username, password);
    }

    /**
     * Holds the transactional data source and its connection for one transaction.
     *
     * <p>Both fields are {@code null} until a connection has been successfully set up, and again
     * after {@link #close()}. They are always assigned together.
     *
     * @author <a href="mailto:hedyn@foxmail.com">HeDYn</a>
     */
    private static final class DataSourceHolder {

        private final Tx tx;
        private TxDataSource dataSource;
        private TxConnection connection;

        private DataSourceHolder(Tx tx) {
            this.tx = tx;
        }

        /**
         * Decides whether {@code ds} takes part in this transaction.
         *
         * @return {@code true} if connections from {@code ds} should be transactional
         * @throws SQLException if {@code @Tx} has no value and {@code ds} differs (by name) from
         *         the data source already bound to this transaction
         */
        private boolean validate(TxDataSource ds) throws SQLException {
            if (tx.value().isEmpty()) {
                if (dataSource != null && !Objects.equals(dataSource.getName(), ds.getName())) {
                    // With an unnamed @Tx only one data source may be used in the transaction;
                    // using several requires the @Tx value to name the transactional one.
                    throw new SQLException("the @Tx value must be set in a transaction if multiple data sources.");
                }
                return true;
            }
            // @Tx names a data source: only a data source whose name matches the annotation value
            // is transactional.
            return Objects.equals(tx.value(), ds.getName());
        }

        /** Creates and configures the transactional connection from {@code ds}, then binds it. */
        private void setupConnection(TxDataSource ds) throws SQLException {
            bind(ds, ds.createTxConnection());
        }

        /** Creates (with credentials) and configures the transactional connection, then binds it. */
        private void setupConnection(TxDataSource ds, String username, String password) throws SQLException {
            bind(ds, ds.createTxConnection(username, password));
        }

        /**
         * Configures {@code conn} and, only on success, records it and {@code ds} as this
         * transaction's connection and data source.
         *
         * <p>Binding only after success keeps the holder consistent when something fails. If the
         * fields were assigned earlier, a failed setup would leave {@code dataSource} set with a
         * missing or unconfigured connection, and later {@code getConnection} calls would skip
         * setup and return it. On configuration failure the connection is closed. Any close
         * failure is attached as a suppressed exception.
         */
        private void bind(TxDataSource ds, TxConnection conn) throws SQLException {
            try {
                configureConnection(conn, tx);
            } catch (SQLException | RuntimeException e) {
                try {
                    conn.reallyClose();
                } catch (Exception closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
            this.dataSource = ds;
            this.connection = conn;
        }

        /** Commits the bound connection; no-op if no connection has been opened. */
        private void commit() throws SQLException {
            if (connection != null) {
                connection.commit();
            }
        }

        /** Rolls back the bound connection; no-op if no connection has been opened. */
        private void rollback() throws SQLException {
            if (connection != null) {
                connection.rollback();
            }
        }

        /**
         * Clears the bound data source and connection, then calls {@code reallyClose()} on the
         * connection if one was opened.
         *
         * <p>The fields are cleared first so the holder is reset even if closing fails.
         */
        private void close() throws SQLException {
            TxConnection tmp = connection;
            this.connection = null;
            this.dataSource = null;
            if (tmp != null) {
                tmp.reallyClose();
            }
        }

        /**
         * Applies the {@code readOnly} and {@code isolation} settings of {@code tx} to
         * {@code conn}.
         *
         * <p>Isolation values other than the four named levels leave the connection's current
         * isolation level unchanged.
         */
        private static void configureConnection(TxConnection conn, Tx tx) throws SQLException {
            conn.setReadOnly(tx.readOnly());
            switch (tx.isolation()) {
                case READ_UNCOMMITTED:
                    conn.setTransactionIsolation(Connection.TRANSACTION_READ_UNCOMMITTED);
                    break;
                case READ_COMMITTED:
                    conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED);
                    break;
                case REPEATABLE_READ:
                    conn.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ);
                    break;
                case SERIALIZABLE:
                    conn.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE);
                    break;
                default:
                    break;
            }
        }

    }

}
